{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "| [05_text_matching/03_篇章粒度文本匹配.ipynb](https://github.com/shibing624/nlp-tutorial/tree/main/05_text_matching/03_篇章粒度文本匹配.ipynb)  | LDA主题提取做Doc相似度计算  |[Open In Colab](https://colab.research.google.com/github/shibing624/nlp-tutorial/blob/main/05_text_matching/03_篇章粒度文本匹配.ipynb) |\n",
    "\n",
    "# 篇章文本匹配\n",
    "\n",
    "1. 主题提取，比较doc相似度\n",
    "2. 长文本分类(参考04_文本分类)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from nltk.tokenize import RegexpTokenizer\n",
    "from gensim import corpora, models\n",
    "import gensim"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fixed seed so that LDA training, which is randomly initialised, is repeatable across runs.\n",
    "RANDOM_SEED = 42\n",
    "\n",
    "# Tokens are maximal runs of word characters, so punctuation is dropped.\n",
    "tokenizer = RegexpTokenizer(r'\\w+')\n",
    "\n",
    "# create sample documents\n",
    "doc_a = \"Brocolli is good to eat. My brother likes to eat good brocolli, but not my mother.\"\n",
    "doc_b = \"My mother spends a lot of time driving my brother around to baseball practice.\"\n",
    "doc_c = \"Some health experts suggest that driving may cause increased tension and blood pressure.\"\n",
    "doc_d = \"I often feel pressure to perform well at school, but my mother never seems to drive my brother to do better.\"\n",
    "doc_e = \"Health professionals say that brocolli is good for your health.\"\n",
    "\n",
    "# compile sample documents into a list\n",
    "doc_set = [doc_a, doc_b, doc_c, doc_d, doc_e]\n",
    "\n",
    "# lowercase and tokenize every document (one token list per document)\n",
    "texts = [tokenizer.tokenize(doc.lower()) for doc in doc_set]\n",
    "\n",
    "# turn our tokenized documents into a id <-> term dictionary\n",
    "dictionary = corpora.Dictionary(texts)\n",
    "\n",
    "# convert tokenized documents into a document-term matrix\n",
    "corpus = [dictionary.doc2bow(text) for text in texts]\n",
    "\n",
    "# generate LDA model\n",
    "ldamodel = gensim.models.ldamodel.LdaModel(\n",
    "    corpus,\n",
    "    num_topics=2,\n",
    "    id2word=dictionary,\n",
    "    passes=20,\n",
    "    random_state=RANDOM_SEED,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[(0, '0.087*\"to\" + 0.087*\"my\" + 0.047*\"brother\" + 0.047*\"mother\"'),\n",
       " (1, '0.074*\"health\" + 0.053*\"that\" + 0.032*\"driving\" + 0.032*\"is\"')]"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "ldamodel.show_topics(num_topics=2, num_words=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[(0, '0.087*\"to\" + 0.087*\"my\" + 0.047*\"brother\"'),\n",
       " (1, '0.074*\"health\" + 0.053*\"that\" + 0.032*\"driving\"')]"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# The model was trained with 2 topics; show the top 3 words of each.\n",
    "ldamodel.show_topics(num_topics=2, num_words=3)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "保存模型："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "ldamodel.save('lda.model')\n",
    "del ldamodel"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "加载模型并预测："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[(0, '0.087*\"to\" + 0.087*\"my\" + 0.047*\"brother\"'),\n",
       " (1, '0.074*\"health\" + 0.053*\"that\" + 0.032*\"driving\"')]"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Reload the model saved to 'lda.model' and confirm it reports the same topics as before.\n",
    "ldamodel = gensim.models.ldamodel.LdaModel.load('lda.model')\n",
    "# The model has 2 topics; show the top 3 words of each.\n",
    "ldamodel.show_topics(num_topics=2, num_words=3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[0.0334055 , 0.0466727 , 0.03333215, 0.03332681, 0.03340471,\n",
       "        0.01991689, 0.01998265, 0.0466696 , 0.08670212, 0.01998262,\n",
       "        0.0867081 , 0.01997586, 0.01997579, 0.0199755 , 0.01984536,\n",
       "        0.01997548, 0.01997629, 0.0199762 , 0.01997623, 0.01997717,\n",
       "        0.00668887, 0.00668821, 0.00668856, 0.00668809, 0.00669494,\n",
       "        0.00668847, 0.00668797, 0.01995114, 0.006689  , 0.00668838,\n",
       "        0.00668861, 0.00669338, 0.01999401, 0.01999387, 0.01999429,\n",
       "        0.01999382, 0.0199941 , 0.01999413, 0.01999382, 0.01999411,\n",
       "        0.01999412, 0.01999333, 0.0199936 , 0.01999399, 0.00669491,\n",
       "        0.00669497, 0.00669473, 0.00669486],\n",
       "       [0.03180207, 0.01068167, 0.01067352, 0.01068201, 0.03180332,\n",
       "        0.03202962, 0.01067961, 0.0106866 , 0.01069377, 0.01067967,\n",
       "        0.01068425, 0.01069043, 0.01069054, 0.01069099, 0.03214349,\n",
       "        0.01069102, 0.01068974, 0.01068988, 0.01068983, 0.01068833,\n",
       "        0.03184251, 0.03184355, 0.031843  , 0.03184375, 0.07432368,\n",
       "        0.03184313, 0.03184393, 0.03197509, 0.0318423 , 0.03184329,\n",
       "        0.03184292, 0.05308075, 0.01066153, 0.01066175, 0.01066109,\n",
       "        0.01066183, 0.01066139, 0.01066134, 0.01066184, 0.01066137,\n",
       "        0.01066136, 0.01066262, 0.01066219, 0.01066157, 0.03183288,\n",
       "        0.03183278, 0.03183317, 0.03183297]], dtype=float32)"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "ldamodel.get_topics()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[([(0.0867081, 'to'),\n",
       "   (0.086702116, 'my'),\n",
       "   (0.0466727, 'brother'),\n",
       "   (0.046669602, 'mother'),\n",
       "   (0.033405498, 'brocolli'),\n",
       "   (0.03340471, 'good'),\n",
       "   (0.033332147, 'but'),\n",
       "   (0.033326812, 'eat'),\n",
       "   (0.019994289, 'do'),\n",
       "   (0.019994127, 'i'),\n",
       "   (0.019994117, 'perform'),\n",
       "   (0.01999411, 'often'),\n",
       "   (0.019994099, 'feel'),\n",
       "   (0.019994013, 'at'),\n",
       "   (0.01999399, 'well'),\n",
       "   (0.019993873, 'better'),\n",
       "   (0.019993823, 'drive'),\n",
       "   (0.019993816, 'never'),\n",
       "   (0.019993596, 'seems'),\n",
       "   (0.019993326, 'school')],\n",
       "  -5.435445407792313),\n",
       " ([(0.074323684, 'health'),\n",
       "   (0.053080745, 'that'),\n",
       "   (0.032143492, 'driving'),\n",
       "   (0.032029618, 'is'),\n",
       "   (0.031975087, 'pressure'),\n",
       "   (0.031843934, 'may'),\n",
       "   (0.031843748, 'experts'),\n",
       "   (0.031843554, 'blood'),\n",
       "   (0.031843293, 'suggest'),\n",
       "   (0.031843133, 'increased'),\n",
       "   (0.031843, 'cause'),\n",
       "   (0.03184292, 'tension'),\n",
       "   (0.031842507, 'and'),\n",
       "   (0.031842303, 'some'),\n",
       "   (0.03183317, 'say'),\n",
       "   (0.031832974, 'your'),\n",
       "   (0.031832885, 'for'),\n",
       "   (0.031832784, 'professionals'),\n",
       "   (0.031803317, 'good'),\n",
       "   (0.03180207, 'brocolli')],\n",
       "  -10.844736713023332)]"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "ldamodel.top_topics(corpus)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<gensim.interfaces.TransformedCorpus at 0x7fa158574e50>"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "ldamodel[corpus]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<gensim.similarities.docsim.MatrixSimilarity at 0x7fa158581820>"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Give the index its vector width up front (one dimension per topic) instead of\n",
    "# letting it infer the width by scanning the corpus, which can under-count when\n",
    "# a topic is filtered out of every document's sparse topic vector.\n",
    "index = gensim.similarities.MatrixSimilarity(ldamodel[corpus], num_features=ldamodel.num_topics)\n",
    "index.save('simIndex.index')\n",
    "index"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_text_sim(text):\n",
    "    \"\"\"Return the similarity of `text` to every document in the similarity index.\n",
    "\n",
    "    The text is lowercased and tokenized with the notebook-level `tokenizer`,\n",
    "    converted to a bag-of-words vector with `dictionary`, projected into topic\n",
    "    space with `ldamodel`, and finally queried against `index`.\n",
    "\n",
    "    Args:\n",
    "        text: Raw document string to compare against the indexed documents.\n",
    "\n",
    "    Returns:\n",
    "        The result of querying `index` with a one-document batch; the calls\n",
    "        below display it as a float32 array with a single row and one column\n",
    "        per indexed document.\n",
    "    \"\"\"\n",
    "    tokens = tokenizer.tokenize(text.lower())\n",
    "    # Wrap the single BoW vector in a list so the model and index receive a\n",
    "    # one-document corpus rather than a bare vector.\n",
    "    doc_bow = [dictionary.doc2bow(tokens)]\n",
    "    # Show the (token_id, count) pairs so the reader can inspect the input vector.\n",
    "    print(doc_bow)\n",
    "    vec_lda = ldamodel[doc_bow]\n",
    "    sims = index[vec_lda]\n",
    "\n",
    "    return sims"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[(0, 2), (1, 1), (2, 1), (3, 2), (4, 2), (5, 1), (6, 1), (7, 1), (8, 2), (9, 1), (10, 2)]]\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "array([[0.99999994, 0.99999577, 0.07365547, 0.9999525 , 0.08877519]],\n",
       "      dtype=float32)"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "get_text_sim(doc_a)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[(0, 1), (4, 1), (5, 1), (24, 2), (31, 1), (44, 1), (45, 1), (46, 1), (47, 1)]]\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "array([[0.08877692, 0.09167764, 0.99988496, 0.07906448, 1.        ]],\n",
       "      dtype=float32)"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "get_text_sim(doc_e)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "可以看出 `doc_e` 与 `doc_c` 以及它自身的相似度接近 1，而与 `doc_a`、`doc_b`、`doc_d` 的相似度都低于 0.1，说明 `doc_e` 的语义更接近主题2（以 `health` 为代表的主题）。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Files left on disk by the save calls earlier in this notebook: the LDA model,\n",
    "# its companion files, and the similarity index.\n",
    "ARTIFACT_PATHS = [\n",
    "    'lda.model',\n",
    "    'lda.model.state',\n",
    "    'lda.model.id2word',\n",
    "    'lda.model.expElogbeta.npy',\n",
    "    'simIndex.index',\n",
    "]\n",
    "\n",
    "for artifact_path in ARTIFACT_PATHS:\n",
    "    # Skip files that are already gone so the cell can be re-run safely and a\n",
    "    # single missing file does not abort the rest of the cleanup.\n",
    "    if os.path.exists(artifact_path):\n",
    "        os.remove(artifact_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "本节完。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}